{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TTS Inference\n",
    "\n",
    "This notebook can be used to generate audio samples using either NeMo's pretrained models or NeMo TTS models that you have trained. The notebook currently uses a two-step inference procedure. First, a model is used to generate a mel spectrogram from text. Second, a model is used to generate audio from the mel spectrogram.\n",
    "\n",
    "Currently supported models are:\n",
    "\n",
    "Mel Spectrogram Generators:\n",
    "- Tacotron 2\n",
    "- Glow-TTS\n",
    "\n",
    "Audio Generators:\n",
    "- WaveGlow\n",
    "- SqueezeWave\n",
    "- UniGlow\n",
    "- MelGAN\n",
    "- HiFiGAN\n",
    "- Two Stage Models\n",
    "    - Griffin-Lim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "> Copyright 2020 NVIDIA. All Rights Reserved.\n",
    "> \n",
    "> Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "> you may not use this file except in compliance with the License.\n",
    "> You may obtain a copy of the License at\n",
    "> \n",
    ">     http://www.apache.org/licenses/LICENSE-2.0\n",
    "> \n",
    "> Unless required by applicable law or agreed to in writing, software\n",
    "> distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "> See the License for the specific language governing permissions and\n",
    "> limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "# # If you're using Google Colab and not running locally, uncomment and run this cell.\n",
    "# !apt-get install sox libsndfile1 ffmpeg\n",
    "# !pip install wget unidecode\n",
    "# BRANCH = 'r1.0.0rc1'\n",
    "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[tts]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "supported_spec_gen = [\"tacotron2\", \"glow_tts\"]\n",
    "supported_audio_gen = [\"waveglow\", \"squeezewave\", \"uniglow\", \"melgan\", \"hifigan\", \"two_stages\"]\n",
    "\n",
    "\n",
    "def _choose(prompt, options):\n",
    "    \"\"\"Print `prompt` and `options`, read one line from stdin and validate it.\n",
    "\n",
    "    Surrounding whitespace is stripped from the answer. The answer is checked\n",
    "    right away, so a typo is reported before the next question is asked.\n",
    "\n",
    "    Args:\n",
    "        prompt: text printed before the list of options.\n",
    "        options: list of accepted answers.\n",
    "\n",
    "    Returns:\n",
    "        The stripped answer, which is guaranteed to be one of `options`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the answer is not one of `options`.\n",
    "    \"\"\"\n",
    "    print(prompt)\n",
    "    print(list(options))\n",
    "    choice = input().strip()\n",
    "    if choice not in options:\n",
    "        raise ValueError(f\"Unsupported choice {choice!r}; expected one of {list(options)}\")\n",
    "    return choice\n",
    "\n",
    "\n",
    "spectrogram_generator = _choose(\"Choose one of the following spectrogram generators:\", supported_spec_gen)\n",
    "audio_generator = _choose(\"Choose one of the following audio generators:\", supported_audio_gen)\n",
    "\n",
    "if audio_generator == \"two_stages\":\n",
    "    # Only the options that need no checkpoint are offered: there is no\n",
    "    # \"encoder_decoder\" (mel-to-spec) or \"degli\" (linear vocoder) checkpoint atm.\n",
    "    # The spelling \"psuedo_inverse\" is kept as-is: it is the value users are asked to type.\n",
    "    supported_mel2spec = [\"psuedo_inverse\"]\n",
    "    supported_linear_vocoders = [\"griffin_lim\"]\n",
    "    mel2spec = _choose(\"Choose one of the following mel-to-spec convertor:\", supported_mel2spec)\n",
    "    linvocoder = _choose(\"Choose one of the following linear spectrogram vocoders:\", supported_linear_vocoders)"
   ]
  },
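  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load model checkpoints\n",
    "\n",
    "Note: For best quality with Glow-TTS, please update the Glow-TTS yaml file with the path to cmudict. For the pretrained model, the next cell does this for you: it downloads `cmudict-0.7b` if it is not already present and writes its path into the model's parser config."
   ]
  },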
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf, open_dict\n",
    "import torch\n",
    "from nemo.collections.asr.parts import parsers\n",
    "from nemo.collections.tts.models.base import SpectrogramGenerator, Vocoder\n",
    "\n",
    "# Audio/STFT settings consumed by the \"two_stages\" vocoder config below.\n",
    "# SAMPLE_RATE agrees with the 22050 Hz playback rate used later in this notebook.\n",
    "# NOTE(review): NFFT, NMEL and FMAX are assumed to mirror the mel settings of the\n",
    "# pretrained spectrogram generators -- confirm against the model configs.\n",
    "SAMPLE_RATE = 22050\n",
    "NFFT = 1024\n",
    "NMEL = 80\n",
    "FMAX = None\n",
    "\n",
    "\n",
    "def load_spectrogram_model():\n",
    "    \"\"\"Load the pretrained spectrogram generator named by `spectrogram_generator`.\n",
    "\n",
    "    For \"glow_tts\", cmudict-0.7b is downloaded into the working directory when it is\n",
    "    not already there, and its absolute path is written into the parser section of\n",
    "    the pretrained config, which is then passed as an override when loading the model.\n",
    "\n",
    "    Returns:\n",
    "        The model returned by `SpectrogramGenerator.from_pretrained`.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `spectrogram_generator` is not a supported name.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the per-model class imports are not referenced directly; presumably\n",
    "    # they make a NeMo build without the model fail early -- confirm before removing.\n",
    "    override_conf = None\n",
    "    if spectrogram_generator == \"tacotron2\":\n",
    "        from nemo.collections.tts.models import Tacotron2Model\n",
    "        pretrained_model = \"tts_en_tacotron2\"\n",
    "    elif spectrogram_generator == \"glow_tts\":\n",
    "        from nemo.collections.tts.models import GlowTTSModel\n",
    "        import wget\n",
    "        from pathlib import Path\n",
    "        pretrained_model = \"tts_en_glowtts\"\n",
    "        cmudict = Path(\"cmudict-0.7b\")\n",
    "        if not cmudict.exists():\n",
    "            cmudict = Path(wget.download(\"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b\"))\n",
    "        filename = str(cmudict.resolve())\n",
    "        conf = SpectrogramGenerator.from_pretrained(pretrained_model, return_config=True)\n",
    "        # The parser settings live either under `parser.params` or directly under `parser`.\n",
    "        if \"params\" in conf.parser:\n",
    "            conf.parser.params.cmu_dict_path = filename\n",
    "        else:\n",
    "            conf.parser.cmu_dict_path = filename\n",
    "        override_conf = conf\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    return SpectrogramGenerator.from_pretrained(pretrained_model, override_config_path=override_conf)\n",
    "\n",
    "\n",
    "def load_vocoder_model():\n",
    "    \"\"\"Load the vocoder named by `audio_generator`.\n",
    "\n",
    "    Single-stage vocoders are loaded through `Vocoder.from_pretrained`. \"two_stages\"\n",
    "    instead builds a `TwoStagesModel` from an in-line config (Griffin-Lim on top of a\n",
    "    mel pseudo-inverse) and, depending on `mel2spec` / `linvocoder`, swaps in the\n",
    "    pretrained encoder-decoder or Deep Griffin-Lim parts.\n",
    "\n",
    "    Returns:\n",
    "        The loaded or constructed vocoder model.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `audio_generator` is not a supported name.\n",
    "    \"\"\"\n",
    "    if audio_generator == \"two_stages\":\n",
    "        from nemo.collections.tts.models import TwoStagesModel\n",
    "        cfg = {'linvocoder':  {'_target_': 'nemo.collections.tts.models.two_stages.GriffinLimModel',\n",
    "                             'cfg': {'n_iters': 64, 'n_fft': NFFT, 'l_hop': 256}},\n",
    "               'mel2spec': {'_target_': 'nemo.collections.tts.models.two_stages.MelPsuedoInverseModel',\n",
    "                           'cfg': {'sampling_rate': SAMPLE_RATE, 'n_fft': NFFT,\n",
    "                                   'mel_fmin': 0, 'mel_fmax': FMAX, 'mel_freq': NMEL}}}\n",
    "        model = TwoStagesModel(cfg)\n",
    "        if mel2spec == \"encoder_decoder\":\n",
    "            from nemo.collections.tts.models.ed_mel2spec import EDMel2SpecModel\n",
    "            pretrained_mel2spec_model = \"EncoderDecoderMelToSpec-22050Hz\"\n",
    "            mel2spec_model = EDMel2SpecModel.from_pretrained(pretrained_mel2spec_model)\n",
    "            model.set_mel_to_spec_model(mel2spec_model)\n",
    "\n",
    "        if linvocoder == \"degli\":\n",
    "            from nemo.collections.tts.models.degli import DegliModel\n",
    "            pretrained_linvocoder_model = \"DeepGriffinLim-22050Hz\"\n",
    "            linvocoder_model = DegliModel.from_pretrained(pretrained_linvocoder_model)\n",
    "            model.set_linear_vocoder(linvocoder_model)\n",
    "\n",
    "        # Built in place, so there is no pretrained checkpoint to fetch below.\n",
    "        return model\n",
    "\n",
    "    # NOTE(review): as above, the class imports are not referenced directly -- confirm\n",
    "    # whether they are still needed before removing them.\n",
    "    if audio_generator == \"waveglow\":\n",
    "        from nemo.collections.tts.models import WaveGlowModel\n",
    "        pretrained_model = \"tts_waveglow_88m\"\n",
    "    elif audio_generator == \"squeezewave\":\n",
    "        from nemo.collections.tts.models import SqueezeWaveModel\n",
    "        pretrained_model = \"tts_squeezewave\"\n",
    "    elif audio_generator == \"uniglow\":\n",
    "        from nemo.collections.tts.models import UniGlowModel\n",
    "        pretrained_model = \"tts_uniglow\"\n",
    "    elif audio_generator == \"melgan\":\n",
    "        from nemo.collections.tts.models import MelGanModel\n",
    "        pretrained_model = \"tts_melgan\"\n",
    "    elif audio_generator == \"hifigan\":\n",
    "        from nemo.collections.tts.models import HifiGanModel\n",
    "        pretrained_model = \"tts_hifigan\"\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    return Vocoder.from_pretrained(pretrained_model)\n",
    "\n",
    "\n",
    "# Use the GPU when PyTorch can see one; otherwise stay on the CPU instead of failing in .cuda().\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "spec_gen = load_spectrogram_model().to(device)\n",
    "vocoder = load_vocoder_model().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def infer(spec_gen_model, vocder_model, str_input):\n",
    "    \"\"\"Synthesize `str_input` with the given spectrogram generator and vocoder.\n",
    "\n",
    "    Args:\n",
    "        spec_gen_model: model providing `parse` and `generate_spectrogram`.\n",
    "        vocder_model: model providing `convert_spectrogram_to_audio`\n",
    "            (parameter name kept as-is for backward compatibility).\n",
    "        str_input: text to synthesize.\n",
    "\n",
    "    Returns:\n",
    "        A `(spectrogram, audio)` tuple. Torch tensors are moved to the CPU and\n",
    "        converted to NumPy arrays; a 3-D spectrogram is reduced to its first\n",
    "        entry along axis 0.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Use the models passed in rather than the notebook-level globals, so the\n",
    "        # function works with whichever model pair the caller supplies.\n",
    "        parsed = spec_gen_model.parse(str_input)\n",
    "        spectrogram = spec_gen_model.generate_spectrogram(tokens=parsed)\n",
    "        audio = vocder_model.convert_spectrogram_to_audio(spec=spectrogram)\n",
    "    if isinstance(spectrogram, torch.Tensor):\n",
    "        spectrogram = spectrogram.detach().cpu().numpy()\n",
    "    if len(spectrogram.shape) == 3:\n",
    "        spectrogram = spectrogram[0]\n",
    "    if isinstance(audio, torch.Tensor):\n",
    "        audio = audio.detach().cpu().numpy()\n",
    "    return spectrogram, audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask for the sentence to synthesize, then run both inference stages on it.\n",
    "text_to_generate = input(\n",
    "    \"Input what you want the model to say: \"\n",
    ")\n",
    "spec, audio = infer(spec_gen, vocoder, str_input=text_to_generate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Show Audio and Spectrogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython.display as ipd\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from matplotlib.pyplot import imshow\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Embed an audio player for the generated waveform, played back at 22050 Hz.\n",
    "ipd.Audio(data=audio, rate=22050)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Draw the spectrogram with its first row at the bottom of the image.\n",
    "plt.imshow(spec, origin=\"lower\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}